# Toy decoder input batch: two sequences of five token ids each.
# NOTE(review): the trailing 0s in the second row are presumably padding ids —
# confirm against the vocabulary used by get_attn_pad_mask.
dec_inputs = torch.tensor(
    [
        [6, 1, 2, 3, 5],
        [6, 4, 3, 0, 0],
    ]
)

# Mask for decoder self-attention, built by passing the decoder inputs as both
# arguments of get_attn_pad_mask.
# NOTE(review): the original comment labelled this a "look-ahead mask", but the
# helper called here is the padding-mask one — confirm which was intended.
dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs)
dec_self_attn_pad_mask = dec_self_attn_pad_mask.to(device)

# Expected shape per the original note: [batch_size, tgt_len, tgt_len]
# (depends on get_attn_pad_mask — TODO confirm).
print(dec_self_attn_pad_mask)
